{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import *\n",
    "from utils import *\n",
    "\n",
    "import os, sys, time, datetime, random\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "from torch.autograd import Variable\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as patches\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_path='config/yolov3.cfg'\n",
    "weights_path='config/yolov3.weights'\n",
    "class_path='config/coco.names'\n",
    "img_size=416\n",
    "conf_thres=0.8\n",
    "nms_thres=0.4\n",
    "\n",
    "# Load model and weights\n",
    "model = Darknet(config_path, img_size=img_size)\n",
    "model.load_weights(weights_path)\n",
    "model.cuda()\n",
    "model.eval()\n",
    "classes = utils.load_classes(class_path)\n",
    "Tensor = torch.cuda.FloatTensor"
   ]
  },
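  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch (not part of the original notebook): device-agnostic setup for\n",
    "# machines without a CUDA GPU. It reuses config_path, weights_path, class_path and\n",
    "# img_size from the cell above; on a CPU-only machine, skip the model.cuda() and\n",
    "# torch.cuda.FloatTensor lines there and run this cell instead.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = Darknet(config_path, img_size=img_size)\n",
    "model.load_weights(weights_path)\n",
    "model.to(device)\n",
    "model.eval()\n",
    "classes = utils.load_classes(class_path)\n",
    "Tensor = torch.cuda.FloatTensor if device.type == 'cuda' else torch.FloatTensor"
   ]
  },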
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def detect_image(img):\n",
    "    # scale and pad image\n",
    "    ratio = min(img_size/img.size[0], img_size/img.size[1])\n",
    "    imw = round(img.size[0] * ratio)\n",
    "    imh = round(img.size[1] * ratio)\n",
    "    img_transforms = transforms.Compose([ transforms.Resize((imh, imw)),\n",
    "         transforms.Pad((max(int((imh-imw)/2),0), max(int((imw-imh)/2),0), max(int((imh-imw)/2),0), max(int((imw-imh)/2),0)),\n",
    "                        (128,128,128)),\n",
    "         transforms.ToTensor(),\n",
    "         ])\n",
    "    # convert image to Tensor\n",
    "    image_tensor = img_transforms(img).float()\n",
    "    image_tensor = image_tensor.unsqueeze_(0)\n",
    "    input_img = Variable(image_tensor.type(Tensor))\n",
    "    # run inference on the model and get detections\n",
    "    with torch.no_grad():\n",
    "        detections = model(input_img)\n",
    "        detections = utils.non_max_suppression(detections, 80, conf_thres, nms_thres)\n",
    "    return detections[0]"
   ]
  },
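  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check (optional; not part of the original notebook): run detect_image\n",
    "# on a single picture and inspect the raw output. Boxes come back in the 416x416\n",
    "# letterboxed frame with the class index in the last column, which is why the video\n",
    "# loop below undoes the padding before drawing.\n",
    "# 'images/test.jpg' is a placeholder path; point it at any local image.\n",
    "test_img = Image.open('images/test.jpg').convert('RGB')\n",
    "test_detections = detect_image(test_img)\n",
    "print(test_detections)\n",
    "if test_detections is not None:\n",
    "    print([classes[int(det[-1])] for det in test_detections])"
   ]
  },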
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "videopath = '../data/video/overpass.mp4'\n",
    "\n",
    "%pylab inline \n",
    "import cv2\n",
    "from IPython.display import clear_output\n",
    "\n",
    "cmap = plt.get_cmap('tab20b')\n",
    "colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]\n",
    "\n",
    "# initialize Sort object and video capture\n",
    "from sort import *\n",
    "vid = cv2.VideoCapture(videopath)\n",
    "mot_tracker = Sort() \n",
    "\n",
    "#while(True):\n",
    "for ii in range(40):\n",
    "    ret, frame = vid.read()\n",
    "    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
    "    pilimg = Image.fromarray(frame)\n",
    "    detections = detect_image(pilimg)\n",
    "\n",
    "    img = np.array(pilimg)\n",
    "    pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))\n",
    "    pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))\n",
    "    unpad_h = img_size - pad_y\n",
    "    unpad_w = img_size - pad_x\n",
    "    if detections is not None:\n",
    "        tracked_objects = mot_tracker.update(detections.cpu())\n",
    "\n",
    "        unique_labels = detections[:, -1].cpu().unique()\n",
    "        n_cls_preds = len(unique_labels)\n",
    "        for x1, y1, x2, y2, obj_id, cls_pred in tracked_objects:\n",
    "            box_h = int(((y2 - y1) / unpad_h) * img.shape[0])\n",
    "            box_w = int(((x2 - x1) / unpad_w) * img.shape[1])\n",
    "            y1 = int(((y1 - pad_y // 2) / unpad_h) * img.shape[0])\n",
    "            x1 = int(((x1 - pad_x // 2) / unpad_w) * img.shape[1])\n",
    "\n",
    "            color = colors[int(obj_id) % len(colors)]\n",
    "            color = [i * 255 for i in color]\n",
    "            cls = classes[int(cls_pred)]\n",
    "            cv2.rectangle(frame, (x1, y1), (x1+box_w, y1+box_h), color, 4)\n",
    "            cv2.rectangle(frame, (x1, y1-35), (x1+len(cls)*19+60, y1), color, -1)\n",
    "            cv2.putText(frame, cls + \"-\" + str(int(obj_id)), (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 3)\n",
    "\n",
    "    fig=figure(figsize=(12, 8))\n",
    "    title(\"Video Stream\")\n",
    "    imshow(frame)\n",
    "    show()\n",
    "    clear_output(wait=True)"
   ]
  },
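  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch (not part of the original notebook): run the same detect-and-track\n",
    "# loop over the whole clip and save the annotated frames to disk instead of showing\n",
    "# them inline. The output name 'output.mp4' and the 'mp4v' codec are assumptions;\n",
    "# use any container/codec your OpenCV build supports.\n",
    "vid = cv2.VideoCapture(videopath)\n",
    "mot_tracker = Sort()\n",
    "fps = vid.get(cv2.CAP_PROP_FPS) or 25.0\n",
    "w = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
    "h = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
    "writer = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))\n",
    "while True:\n",
    "    ret, frame = vid.read()\n",
    "    if not ret:\n",
    "        break\n",
    "    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
    "    detections = detect_image(Image.fromarray(rgb))\n",
    "    if detections is not None:\n",
    "        pad_x = max(rgb.shape[0] - rgb.shape[1], 0) * (img_size / max(rgb.shape))\n",
    "        pad_y = max(rgb.shape[1] - rgb.shape[0], 0) * (img_size / max(rgb.shape))\n",
    "        unpad_h, unpad_w = img_size - pad_y, img_size - pad_x\n",
    "        for x1, y1, x2, y2, obj_id, cls_pred in mot_tracker.update(detections.cpu()):\n",
    "            box_h = int(((y2 - y1) / unpad_h) * rgb.shape[0])\n",
    "            box_w = int(((x2 - x1) / unpad_w) * rgb.shape[1])\n",
    "            y1 = int(((y1 - pad_y // 2) / unpad_h) * rgb.shape[0])\n",
    "            x1 = int(((x1 - pad_x // 2) / unpad_w) * rgb.shape[1])\n",
    "            color = tuple(int(c * 255) for c in colors[int(obj_id) % len(colors)])\n",
    "            label = classes[int(cls_pred)] + '-' + str(int(obj_id))\n",
    "            cv2.rectangle(rgb, (x1, y1), (x1 + box_w, y1 + box_h), color, 4)\n",
    "            cv2.putText(rgb, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 3)\n",
    "    writer.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))  # VideoWriter expects BGR frames\n",
    "vid.release()\n",
    "writer.release()"
   ]
  },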
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3.6",
   "language": "python",
   "name": "python3.6"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
